{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🧪 ADK with A2A Application Testing\n",
    "\n",
    "This notebook demonstrates how to test an ADK (Agent Development Kit) application that implements the Agent2Agent (A2A) protocol.\n",
    "It covers both local and remote testing for both Agent Engine and Cloud Run deployments."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up Your Environment\n",
    "\n",
    "> **Note:** For best results, use the same `.venv` created for local development with `uv` to ensure dependency compatibility and avoid environment-related issues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment the following lines if you're not using the virtual environment created by uv\n",
    "# import sys\n",
    "\n",
    "# sys.path.append(\"../\")\n",
    "# %pip (unlike !pip) installs into the environment of the running kernel\n",
    "# %pip install google-cloud-aiplatform google-adk a2a-sdk --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ruff: noqa\n",
    "import asyncio\n",
    "import json\n",
    "import os\n",
    "import requests\n",
    "import uuid\n",
    "\n",
    "import vertexai\n",
    "from a2a.types import (\n",
    "    Message,\n",
    "    MessageSendParams,\n",
    "    Part,\n",
    "    Role,\n",
    "    SendStreamingMessageRequest,\n",
    "    TextPart,\n",
    ")\n",
    "from IPython.display import Markdown, display\n",
    "from google.adk.artifacts import InMemoryArtifactService\n",
    "from google.adk.sessions import InMemorySessionService\n",
    "\n",
    "from app.agent_engine_app import AgentEngineApp\n",
    "from tests.helpers import (\n",
    "    build_get_request,\n",
    "    build_post_request,\n",
    "    poll_task_completion,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize Vertex AI Client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vertex AI client for the region configured below\n",
    "LOCATION = \"us-central1\"\n",
    "client = vertexai.Client(location=LOCATION)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## If you are using Agent Engine\n",
    "See more documentation at [Agent Engine Overview](https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/overview)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Remote Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set to None to auto-detect from ../deployment_metadata.json, or specify manually, e.g.\n",
    "# \"projects/PROJECT_ID/locations/us-central1/reasoningEngines/ENGINE_ID\"\n",
    "REASONING_ENGINE_ID = None\n",
    "DEPLOYMENT_METADATA_PATH = \"../deployment_metadata.json\"\n",
    "\n",
    "if REASONING_ENGINE_ID is None:\n",
    "    try:\n",
    "        with open(DEPLOYMENT_METADATA_PATH) as f:\n",
    "            metadata = json.load(f)\n",
    "        REASONING_ENGINE_ID = metadata.get(\"remote_agent_engine_id\")\n",
    "    except (FileNotFoundError, json.JSONDecodeError) as e:\n",
    "        # Auto-detection is best-effort: report why it failed instead of silently passing.\n",
    "        print(f\"Could not read {DEPLOYMENT_METADATA_PATH}: {e}\")\n",
    "\n",
    "# Fail fast with an actionable message instead of an AttributeError on None.split().\n",
    "if not REASONING_ENGINE_ID:\n",
    "    raise ValueError(\n",
    "        \"REASONING_ENGINE_ID is not set. Specify it manually above, or deploy the agent so that \"\n",
    "        f\"{DEPLOYMENT_METADATA_PATH} contains 'remote_agent_engine_id'.\"\n",
    "    )\n",
    "\n",
    "print(f\"Using REASONING_ENGINE_ID: {REASONING_ENGINE_ID}\")\n",
    "\n",
    "# Expected format: projects/{project_id}/locations/{location}/reasoningEngines/{engine_id}\n",
    "parts = REASONING_ENGINE_ID.split(\"/\")\n",
    "if len(parts) != 6 or parts[0::2] != [\"projects\", \"locations\", \"reasoningEngines\"]:\n",
    "    raise ValueError(f\"Unexpected REASONING_ENGINE_ID format: {REASONING_ENGINE_ID!r}\")\n",
    "project_id, location, engine_id = parts[1], parts[3], parts[5]\n",
    "\n",
    "# Construct API endpoints\n",
    "base_url = f\"https://{location}-aiplatform.googleapis.com\"\n",
    "a2a_base_path = f\"/v1beta1/projects/{project_id}/locations/{location}/reasoningEngines/{engine_id}/a2a/v1\"\n",
    "\n",
    "print(f\"Base URL: {base_url}\")\n",
    "print(f\"A2A base path: {a2a_base_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Fetch Agent Card"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch agent card using REST API\n",
    "import google.auth\n",
    "import google.auth.transport.requests\n",
    "\n",
    "# Get an OAuth access token from Application Default Credentials\n",
    "creds, project = google.auth.default()\n",
    "auth_req = google.auth.transport.requests.Request()\n",
    "creds.refresh(auth_req)\n",
    "\n",
    "headers = {\"Content-Type\": \"application/json\", \"Authorization\": f\"Bearer {creds.token}\"}\n",
    "\n",
    "# GET request to fetch agent card. requests has no default timeout, so set one\n",
    "# to keep the cell from hanging forever if the endpoint is unreachable.\n",
    "response = requests.get(\n",
    "    f\"{base_url}{a2a_base_path}/card\",\n",
    "    headers=headers,\n",
    "    timeout=60,\n",
    ")\n",
    "\n",
    "print(f\"Response status code: {response.status_code}\")\n",
    "\n",
    "if response.status_code == 200:\n",
    "    remote_a2a_agent_card = response.json()\n",
    "    print(f\"Agent: {remote_a2a_agent_card.get('name')}\")\n",
    "    print(f\"URL: {remote_a2a_agent_card.get('url')}\")\n",
    "    print(\n",
    "        f\"Skills: {[s.get('description') for s in remote_a2a_agent_card.get('skills', [])]}\"\n",
    "    )\n",
    "    print(f\"Protocol Version: {remote_a2a_agent_card.get('protocolVersion')}\")\n",
    "else:\n",
    "    print(f\"Error: {response.text}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Send Message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Send the message using A2A REST API\n",
    "import google.auth\n",
    "import google.auth.transport.requests\n",
    "\n",
    "# Get a fresh authentication token from Application Default Credentials\n",
    "creds, project = google.auth.default()\n",
    "auth_req = google.auth.transport.requests.Request()\n",
    "creds.refresh(auth_req)\n",
    "\n",
    "headers = {\"Content-Type\": \"application/json\", \"Authorization\": f\"Bearer {creds.token}\"}\n",
    "\n",
    "data = {\n",
    "    \"message\": {\n",
    "        \"messageId\": f\"msg-{os.urandom(8).hex()}\",\n",
    "        \"content\": [{\"text\": \"What is the weather in New York?\"}],\n",
    "        \"role\": \"ROLE_USER\",\n",
    "    }\n",
    "}\n",
    "\n",
    "# Send POST request to message:send endpoint (with a timeout so the cell cannot hang)\n",
    "response = requests.post(\n",
    "    f\"{base_url}{a2a_base_path}/message:send\",\n",
    "    headers=headers,\n",
    "    json=data,\n",
    "    timeout=60,\n",
    ")\n",
    "\n",
    "print(f\"Response status code: {response.status_code}\")\n",
    "\n",
    "if response.status_code != 200:\n",
    "    # Stop here: the polling cell needs a valid `task_id`. Only printing would let it fail\n",
    "    # with a NameError, or silently poll a stale task from an earlier run.\n",
    "    raise RuntimeError(f\"message:send failed ({response.status_code}): {response.text}\")\n",
    "\n",
    "response_data = response.json()\n",
    "task_id = response_data[\"task\"][\"id\"]\n",
    "print(f\"Task started: {task_id}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Poll for response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Poll for task completion using REST API\n",
    "MAX_POLL_ATTEMPTS = 30\n",
    "POLL_INTERVAL_SECONDS = 1\n",
    "TERMINAL_FAILURE_STATES = {\"TASK_STATE_FAILED\", \"TASK_STATE_CANCELLED\"}\n",
    "\n",
    "# Last successfully fetched task payload; stays None if the very first poll fails.\n",
    "result = None\n",
    "for attempt in range(MAX_POLL_ATTEMPTS):\n",
    "    poll_response = requests.get(\n",
    "        f\"{base_url}{a2a_base_path}/tasks/{task_id}\",\n",
    "        headers=headers,\n",
    "        timeout=60,\n",
    "    )\n",
    "\n",
    "    if poll_response.status_code != 200:\n",
    "        print(f\"Poll failed with status code: {poll_response.status_code}\")\n",
    "        break\n",
    "\n",
    "    result = poll_response.json()\n",
    "    task_state = result.get(\"status\", {}).get(\"state\")\n",
    "    print(f\"Attempt {attempt + 1}: {task_state}\")\n",
    "\n",
    "    if task_state == \"TASK_STATE_COMPLETED\":\n",
    "        print(\"Task completed!\")\n",
    "        break\n",
    "    elif task_state in TERMINAL_FAILURE_STATES:\n",
    "        print(f\"Task failed: {result}\")\n",
    "        break\n",
    "\n",
    "    await asyncio.sleep(POLL_INTERVAL_SECONDS)\n",
    "else:\n",
    "    # Loop ran out without a `break`: the task never reached a terminal state.\n",
    "    print(f\"Task did not finish after {MAX_POLL_ATTEMPTS} attempts.\")\n",
    "\n",
    "# Extract and display artifacts\n",
    "artifacts = (result or {}).get(\"artifacts\") or []\n",
    "if artifacts:\n",
    "    for artifact in artifacts:\n",
    "        for part in artifact.get(\"parts\") or []:\n",
    "            if \"text\" in part:\n",
    "                display(Markdown(f\"**Answer**:\\n {part['text']}\"))\n",
    "            else:\n",
    "                print(\"Could not extract text from artifact parts.\")\n",
    "else:\n",
    "    print(\"No artifacts found in result\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Local Testing\n",
    "\n",
    "You can import the app directly into your environment and call its handlers without deploying it.\n",
    "To run the agent locally, follow these steps:\n",
    "1. Make sure all required packages are installed in your environment\n",
    "2. The recommended approach is to use the same virtual environment created by the 'uv' tool\n",
    "3. You can set up this environment by running 'make install' from your agent's root directory\n",
    "4. Then select this kernel (.venv folder in your project) in your Jupyter notebook to ensure all dependencies are available"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the locally defined `agent_engine` and initialize it before invoking its handlers below\n",
    "from app.agent_engine_app import agent_engine\n",
    "\n",
    "agent_engine.set_up()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Verify Custom Method is Registered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the operations registered by the app to confirm the custom methods are exposed\n",
    "registered_operations = agent_engine.register_operations()\n",
    "print(registered_operations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Fetch Agent Card"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a request with the test helper and await the agent-card handler directly\n",
    "request = build_get_request(None)\n",
    "response = await agent_engine.handle_authenticated_agent_card(request=request, context=None)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Send Message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same payload shape as the remote message:send request\n",
    "user_message = {\n",
    "    \"messageId\": f\"msg-{os.urandom(8).hex()}\",\n",
    "    \"content\": [{\"text\": \"What is the weather in New York?\"}],\n",
    "    \"role\": \"ROLE_USER\",\n",
    "}\n",
    "message_data = {\"message\": user_message}\n",
    "\n",
    "request = build_post_request(message_data)\n",
    "response = await agent_engine.on_message_send(request=request, context=None)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Poll for response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_id = response[\"task\"][\"id\"]\n",
    "print(f\"The Task ID is: {task_id}\")\n",
    "\n",
    "# Poll for completion using helper\n",
    "final_response = await poll_task_completion(agent_engine, task_id)\n",
    "\n",
    "# Extract and display every text part of every artifact (matching the remote polling cell).\n",
    "# Use .get so a response without artifacts/parts is reported instead of raising KeyError.\n",
    "artifacts = final_response.get(\"artifacts\") or []\n",
    "if not artifacts:\n",
    "    print(\"No artifacts found in result\")\n",
    "for artifact in artifacts:\n",
    "    artifact_parts = artifact.get(\"parts\") or []\n",
    "    text_parts = [p[\"text\"] for p in artifact_parts if \"text\" in p]\n",
    "    if not text_parts:\n",
    "        print(\"Could not extract text from artifact parts.\")\n",
    "    for text in text_parts:\n",
    "        display(Markdown(f\"**Answer**:\\n {text}\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Register Feedback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Send a sample feedback record through the app's register_feedback operation\n",
    "sample_feedback = {\n",
    "    \"score\": 5,\n",
    "    \"text\": \"Great response!\",\n",
    "    \"user_id\": \"test-user-123\",\n",
    "    \"session_id\": \"test-session-123\",\n",
    "}\n",
    "agent_engine.register_feedback(feedback=sample_feedback)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## If you are using Cloud Run"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Remote Testing\n",
    "\n",
    "For more information about authenticating HTTPS requests to Cloud Run services, see:\n",
    "[Cloud Run Authentication Documentation](https://cloud.google.com/run/docs/triggering/https-request)\n",
    "\n",
    "Remote testing involves using a deployed service URL instead of localhost.\n",
    "\n",
    "Authentication is handled using GCP identity tokens instead of local credentials."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get an identity token for the Cloud Run service from the active gcloud account\n",
    "token_output = get_ipython().getoutput(\"gcloud auth print-identity-token -q\")\n",
    "ID_TOKEN = token_output[0].strip() if token_output else \"\"\n",
    "# A valid token is a single whitespace-free string. Anything else is an error message\n",
    "# (e.g. not logged in, gcloud not installed) and must not be sent as a bearer token.\n",
    "if not ID_TOKEN or \" \" in ID_TOKEN:\n",
    "    raise RuntimeError(\n",
    "        \"Could not obtain an identity token from gcloud. \"\n",
    "        \"Run `gcloud auth login` and try again.\\n\" + \"\\n\".join(token_output)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace with your Cloud Run service URL, or set the SERVICE_URL environment variable.\n",
    "# A trailing slash is stripped so the endpoint URL built below does not contain \"//\".\n",
    "SERVICE_URL = os.environ.get(\"SERVICE_URL\", \"YOUR_SERVICE_URL_HERE\").rstrip(\"/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Send a message using the A2A protocol and print the streamed events."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create A2A message request\n",
    "message = Message(\n",
    "    message_id=f\"msg-user-{uuid.uuid4()}\",\n",
    "    role=Role.user,\n",
    "    parts=[Part(root=TextPart(text=\"Hello! Weather in New York?\"))],\n",
    ")\n",
    "\n",
    "request = SendStreamingMessageRequest(\n",
    "    id=f\"req-{uuid.uuid4()}\",\n",
    "    params=MessageSendParams(message=message),\n",
    ")\n",
    "\n",
    "# Set up headers with authentication\n",
    "headers = {\"Content-Type\": \"application/json\", \"Authorization\": f\"Bearer {ID_TOKEN}\"}\n",
    "\n",
    "# Send the streaming request to the A2A endpoint. The `with` block closes the streamed\n",
    "# connection even if parsing fails part-way through.\n",
    "with requests.post(\n",
    "    f\"{SERVICE_URL}/a2a/app\",\n",
    "    headers=headers,\n",
    "    json=request.model_dump(mode=\"json\", exclude_none=True),\n",
    "    stream=True,\n",
    "    timeout=60,\n",
    ") as response:\n",
    "    print(f\"Response status code: {response.status_code}\")\n",
    "    # Surface HTTP errors (e.g. 401/403 for a bad token) instead of printing nothing.\n",
    "    if not response.ok:\n",
    "        raise RuntimeError(f\"A2A request failed ({response.status_code}): {response.text}\")\n",
    "\n",
    "    # Parse streaming A2A responses: each `data: ` line carries one JSON event.\n",
    "    # Decode explicitly as UTF-8 rather than relying on the response's declared charset.\n",
    "    for line in response.iter_lines():\n",
    "        if not line:\n",
    "            continue\n",
    "        line_str = line.decode(\"utf-8\")\n",
    "        if line_str.startswith(\"data: \"):\n",
    "            event = json.loads(line_str[len(\"data: \"):])\n",
    "            print(f\"Received event: {event}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Local Testing\n",
    "\n",
    "> You can run the application locally via the `make local-backend` command.\n",
    "\n",
    "Send a message to the local backend service using the A2A protocol and receive a streaming response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create A2A message request\n",
    "message = Message(\n",
    "    message_id=f\"msg-user-{uuid.uuid4()}\",\n",
    "    role=Role.user,\n",
    "    parts=[Part(root=TextPart(text=\"Hello! Weather in New York?\"))],\n",
    ")\n",
    "\n",
    "request = SendStreamingMessageRequest(\n",
    "    id=f\"req-{uuid.uuid4()}\",\n",
    "    params=MessageSendParams(message=message),\n",
    ")\n",
    "\n",
    "# Address of the backend started with `make local-backend`\n",
    "LOCAL_BACKEND_URL = \"http://127.0.0.1:8000\"\n",
    "\n",
    "# Set up headers\n",
    "headers = {\"Content-Type\": \"application/json\"}\n",
    "\n",
    "# Send the streaming request to the local A2A endpoint. The `with` block closes the\n",
    "# streamed connection even if parsing fails part-way through.\n",
    "with requests.post(\n",
    "    f\"{LOCAL_BACKEND_URL}/a2a/app\",\n",
    "    headers=headers,\n",
    "    json=request.model_dump(mode=\"json\", exclude_none=True),\n",
    "    stream=True,\n",
    "    timeout=60,\n",
    ") as response:\n",
    "    print(f\"Response status code: {response.status_code}\")\n",
    "    # Surface HTTP errors instead of silently printing no events.\n",
    "    if not response.ok:\n",
    "        raise RuntimeError(f\"A2A request failed ({response.status_code}): {response.text}\")\n",
    "\n",
    "    # Parse streaming A2A responses: each `data: ` line carries one JSON event.\n",
    "    # Decode explicitly as UTF-8 rather than relying on the response's declared charset.\n",
    "    for line in response.iter_lines():\n",
    "        if not line:\n",
    "            continue\n",
    "        line_str = line.decode(\"utf-8\")\n",
    "        if line_str.startswith(\"data: \"):\n",
    "            event = json.loads(line_str[len(\"data: \"):])\n",
    "            print(f\"Received event: {event}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myagent-1761660603",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}